
Conversation

@mgoin (Member) commented Nov 22, 2025

Purpose

Expand csrc/quantization/fp4/nvfp4_experts_quant.cu and csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu to build for SM120 (RTX 50xx and RTX PRO 6000 Blackwell) so we can support the NVFP4 CUTLASS MoE path on that platform.

Thanks to @AichenF in sgl-project/sglang#11737 for the extension to the SM100 kernel.

Hoping to address #29030, #29141

Test Plan

Eval on RTX 5090

Test Result

Main:
Fails to run: the CUTLASS path is taken, but the required kernels are not built for SM120.

PR:

Test compressed-tensors checkpoint https://huggingface.co/RedHatAI/Qwen3-30B-A3B-NVFP4

vllm serve RedHatAI/Qwen3-30B-A3B-NVFP4
python tests/evals/gsm8k/gsm8k_eval.py
Running GSM8K evaluation: 1319 questions, 5-shot
Evaluating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:37<00:00, 35.37it/s]

Results:
Accuracy: 0.897
Invalid responses: 0.000
Total latency: 37.318 s
Questions per second: 35.345
Total output tokens: 153114
Output tokens per second: 4102.978

Test modelopt checkpoint https://huggingface.co/nvidia/Qwen3-30B-A3B-NVFP4

vllm serve nvidia/Qwen3-30B-A3B-NVFP4
python tests/evals/gsm8k/gsm8k_eval.py
Running GSM8K evaluation: 1319 questions, 5-shot
Evaluating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:37<00:00, 35.16it/s]

Results:
Accuracy: 0.883
Invalid responses: 0.001
Total latency: 37.537 s
Questions per second: 35.138
Total output tokens: 152699
Output tokens per second: 4067.910


@gemini-code-assist bot left a comment:

Code Review

This pull request adds support for NVFP4 MoE kernels on SM120 architecture using CUTLASS. The changes include adding the new kernel file to the build system, implementing the SM120-specific kernel, and creating a dispatcher to select the correct kernel based on the SM version.

The implementation for the SM120 kernel introduces significant code duplication with the existing SM100 kernel. I've left a comment suggesting a refactoring to improve maintainability by using templates to abstract away the architecture-specific details, similar to patterns seen elsewhere in the codebase. Other changes look good.
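
A minimal, self-contained sketch of the runtime SM-version dispatch described above (the function names here are hypothetical stand-ins, not vLLM's actual entry points):

```cpp
#include <cuda_runtime.h>

#include <cstdio>
#include <stdexcept>

// Hypothetical stand-ins for the SM100/SM120 grouped-GEMM launchers.
void run_fp4_group_mm_sm100() { std::printf("SM100 kernel path\n"); }
void run_fp4_group_mm_sm120() { std::printf("SM120 kernel path\n"); }

// Pick a kernel based on the device's compute capability.
void dispatch_fp4_group_mm(int device) {
  cudaDeviceProp props{};
  cudaGetDeviceProperties(&props, device);
  const int sm = props.major * 10 + props.minor;  // 120 = RTX 50xx, 121 = DGX Spark
  if (sm >= 120) {
    run_fp4_group_mm_sm120();
  } else if (sm >= 100) {
    run_fp4_group_mm_sm100();
  } else {
    throw std::runtime_error("NVFP4 CUTLASS MoE requires SM100 or newer");
  }
}
```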

@mgoin mgoin added kernel moe ready ONLY add when PR is ready to merge/full CI is needed labels Nov 22, 2025
@mgoin mgoin removed the ready ONLY add when PR is ready to merge/full CI is needed label Nov 22, 2025
@mgoin mgoin marked this pull request as draft November 23, 2025 00:15
@bbrowning (Contributor) commented:

Debugging some things here locally with trace output enabled:

/home/bbrowning/src/vllm/cmake-build-release/_deps/cutlass-src/include/cutlass/gemm/device/gemm_universal_adapter.h:244    workspace_bytes: 43008
/home/bbrowning/src/vllm/cmake-build-release/_deps/cutlass-src/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp:422    CAN IMPLEMENT: Ignoring check to can implement because host problem shape is not available.

/home/bbrowning/src/vllm/cmake-build-release/_deps/cutlass-src/include/cutlass/gemm/device/gemm_universal_adapter.h:312  GemmUniversal::initialize() - workspace 0xe904359b1400, stream: null
/home/bbrowning/src/vllm/cmake-build-release/_deps/cutlass-src/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp:240  to_underlying_arguments():
/home/bbrowning/src/vllm/cmake-build-release/_deps/cutlass-src/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp:251  to_underlying_arguments(): Setting persistent grid SM count to 48
/home/bbrowning/src/vllm/cmake-build-release/_deps/cutlass-src/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp:257    WARNING: Arguments do not include a valid max cluster count. For optimal performance, populate the arguments KernelHardwareInfo struct with the max_active_clusters.
/home/bbrowning/src/vllm/cmake-build-release/_deps/cutlass-src/include/cutlass/gemm/device/gemm_universal_adapter.h:336    Setting smem size to 89088
/home/bbrowning/src/vllm/cmake-build-release/_deps/cutlass-src/include/cutlass/gemm/device/gemm_universal_adapter.h:343    cudaFuncSetAttribute() returned error: invalid resource handle

I wonder if we're going to end up needing some of the fixes from CUTLASS 4.3.0, such as https://github.com/NVIDIA/cutlass/pull/2789/files#diff-17feae83c50669a768a86d30b7bd9c50feb773c10c38b85b253f15e7462283d0 or https://github.com/NVIDIA/cutlass/pull/2789/files#diff-b743e12813d0ca6abe40a6711013c612bb15f7affeba65c7d7d4392d7e3d406d

@bbrowning (Contributor) commented:

So, after some digging around, I got this seemingly working by adjusting run_fp4_blockwise_scaled_group_mm_sm120 to not be templated on OutType and instead hardcoding it to cutlass::bfloat16_t. That's obviously a hack, and I'm not sure why the template is causing trouble. I went through quite a few iterations of other changes before getting this new NVFP4 MoE support to actually run on my DGX Spark, so I'll see if I can reduce them to the smallest set of changes on top of the latest commit here, so you can assess whether it's useful as a way to get this working.

@mgoin (Member, Author) commented Nov 23, 2025

Wow, I couldn't figure out the GEMM initialization failure for several hours yesterday; even the CUTLASS debug logs were giving me nothing. If that fixes the issue for me, thank you so much.

@bbrowning (Contributor) commented:

I tried explicit template instantiation, and that did not work. I tried a static inline template, and non-templated _bf16- and _f16-suffixed wrappers dispatched to based on the output data type, and those did not work either. So, to test this locally, I just copied the methods, got rid of the template, and dispatched based on the output data type. I don't know whether this is a CUDA 13.0 thing, an nvcc bug, or something else; it's a bit outside my usual area of expertise.

Here's the gist with what seems to work for me: https://gist.github.com/bbrowning/fb0aac4970a881eba26384d262cd41c9. That's entirely untested on anything but a DGX Spark and the model RedHatAI/Qwen3-30B-A3B-NVFP4. Perhaps a macro could reduce the code duplication there, along the lines of the sketch below.
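
A toy, self-contained sketch of that macro idea (the names and dispatch enum are hypothetical illustrations, not the gist's actual code; the real bodies would launch the CUTLASS grouped GEMM with OUT_T as the epilogue output type):

```cpp
#include <cstdio>
#include <stdexcept>

enum class OutDtype { kBFloat16, kFloat16 };

// Stamp out one non-templated variant per output type; only the string here
// uses OUT_T, so this toy compiles without the CUTLASS headers.
#define DEFINE_FP4_GROUP_MM_VARIANT(SUFFIX, OUT_T)          \
  void run_fp4_group_mm_##SUFFIX() {                        \
    std::printf("non-templated variant for " #OUT_T "\n");  \
  }

DEFINE_FP4_GROUP_MM_VARIANT(bf16, cutlass::bfloat16_t)
DEFINE_FP4_GROUP_MM_VARIANT(f16, cutlass::half_t)

// Runtime dispatch on the requested output dtype, replacing the template.
void run_fp4_group_mm(OutDtype dtype) {
  switch (dtype) {
    case OutDtype::kBFloat16: return run_fp4_group_mm_bf16();
    case OutDtype::kFloat16: return run_fp4_group_mm_f16();
  }
  throw std::runtime_error("unsupported output dtype");
}

int main() { run_fp4_group_mm(OutDtype::kBFloat16); }
```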

@bbrowning (Contributor) commented:

I don't know what to expect performance-wise, but this feels like it's taking advantage of the new NVFP4 kernels? The output tokens/s is higher than I expected...

$ vllm bench serve \
  --backend openai-chat \
  --endpoint /v1/chat/completions \
  --dataset-name random \
  --random-input-len 300 \
  --random-output-len 200 \
  --ignore-eos \
  --num-prompts 100 \
  --model RedHatAI/Qwen3-30B-A3B-NVFP4 \
  --base-url http://localhost:8000

...

============ Serving Benchmark Result ============
Successful requests:                     100       
Failed requests:                         0         
Benchmark duration (s):                  16.80     
Total input tokens:                      30000     
Total generated tokens:                  20000     
Request throughput (req/s):              5.95      
Output token throughput (tok/s):         1190.78   
Peak output token throughput (tok/s):    1505.00   
Peak concurrent requests:                100.00    
Total Token throughput (tok/s):          2976.96   
---------------Time to First Token----------------
Mean TTFT (ms):                          235.67    
Median TTFT (ms):                        239.96    
P99 TTFT (ms):                           245.96    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          83.15     
Median TPOT (ms):                        83.15     
P99 TPOT (ms):                           83.32     
---------------Inter-token Latency----------------
Mean ITL (ms):                           82.73     
Median ITL (ms):                         87.11     
P99 ITL (ms):                            94.02     
==================================================

@mergify bot commented Nov 24, 2025

Documentation preview: https://vllm--29242.org.readthedocs.build/en/29242/

@mergify mergify bot added the documentation Improvements or additions to documentation label Nov 24, 2025
@mgoin mgoin marked this pull request as ready for review November 24, 2025 18:12
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 24, 2025
@mergify bot commented Nov 24, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @mgoin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 24, 2025
@mergify mergify bot removed the needs-rebase label Nov 24, 2025
@bbrowning (Contributor) commented:

I pulled the latest changes here onto my SM121 and was able to load RedHatAI/Qwen3-30B-A3B-NVFP4 and RedHatAI/Llama-4-Scout-17B-16E-Instruct-NVFP4 (both NVFP4 MoE models) successfully, without having to make any other local changes. Thanks!

fusion_args.alpha_ptr_array =
    reinterpret_cast<float**>(alpha_ptrs.data_ptr());
fusion_args.dAlpha = {_0{}, _0{}, 1};
fusion_args.beta = 0.0f;
Collaborator commented:

Does the sm100 version need beta set to 0 also?

@mgoin (Member, Author) replied:

Doesn't seem to be; the sm120 kernel is a bit unique, as seen from the bf16 issue.
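
(For context: the standard CUTLASS linear-combination epilogue computes $D = \alpha\,(AB) + \beta\,C$, so fusion_args.beta = 0.0f simply zeroes out the C-operand contribution.)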

@bnellnm (Collaborator) left a comment:

LGTM. I'm not a cutlass expert but it looks close enough to the sm100 version.

Comment on lines +352 to +354
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size, " num_experts=", num_experts,
" M=", M, " N=", N, " K=", K);
@ElizaWszola (Contributor) commented Nov 25, 2025:

nit: Would it be useful to add a similar detailed message to the SM100 kernel?

@mgoin (Member, Author) replied:

I want to improve this across the board for our CUTLASS kernels, so I'll do it in a follow-up.

@ElizaWszola (Contributor) commented:

lgtm also!

@vllm-bot vllm-bot merged commit e502098 into vllm-project:main Nov 25, 2025
90 of 91 checks passed
@github-project-automation github-project-automation bot moved this to Done in NVIDIA Nov 25, 2025
bringlein pushed a commit to bringlein/vllm that referenced this pull request Nov 26, 2025

Labels

ci/build · documentation · kernel · moe · nvidia · ready

Projects

Status: Done

5 participants